[WIP] Add Flax diverse group search #24508
Conversation
sanchit-gandhi
left a comment
This looks promising already @yeandy! Left some comments regarding the design below. In addition, could we add a few tests to confirm that:
- Group beam search runs when we call `model.generate`
- That group beam search is jit'able
- And that we get equivalence with PyTorch
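A rough sketch of what such tests could look like (the checkpoint, generation kwargs, and the Flax-side `num_beam_groups` / `diversity_penalty` arguments are assumptions, since they are exactly what this PR would add to Flax `generate`):

```python
# Hedged sketch only: assumes this PR adds `num_beam_groups` / `diversity_penalty`
# to FlaxGenerationMixin.generate; "t5-small" and the kwargs are placeholders.
import numpy as np
import torch
from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
flax_model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")
pt_model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: How are you?", return_tensors="np")
gen_kwargs = dict(num_beams=4, num_beam_groups=2, diversity_penalty=1.0, max_length=20)

# 1) group beam search runs through `generate`
flax_sequences = flax_model.generate(inputs["input_ids"], **gen_kwargs).sequences

# 2) equivalence with PyTorch's group beam search
# (padding / sequence lengths may need aligning before an exact comparison)
pt_sequences = pt_model.generate(torch.from_numpy(np.asarray(inputs["input_ids"])), **gen_kwargs)
assert flax_sequences.tolist() == pt_sequences.tolist()
```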
```python
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
num_return_sequences: Optional[int] = None,
num_beam_groups: Optional[int] = 1,
```
In PyTorch we define a separate beam search method for group beam search: `transformers/src/transformers/generation/utils.py`, line 3375 in 33b5ef5.
We only trigger this method if `num_beam_groups > 1`: `transformers/src/transformers/generation/utils.py`, line 1426 in 33b5ef5.
My opinion is that we should have a separate group beam search method in Flax as well, rather than adding to the existing one. IMO this is cleaner for the reader and more compartmentalised to build on.
cc @gante as well for Flax generate design decision
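A minimal sketch of the dispatch being suggested, mirroring the PyTorch split. The `_group_beam_search` method and the exact argument plumbing are hypothetical; `_beam_search` and `_greedy_search` stand in for the existing Flax code paths:

```python
# Illustrative sketch only, not the existing transformers implementation.
class FlaxGenerationMixinSketch:
    def generate(self, input_ids, num_beams=1, num_beam_groups=1, **kwargs):
        if num_beam_groups > 1:
            if num_beams % num_beam_groups != 0:
                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
            # new, separate method added by this PR (name is illustrative)
            return self._group_beam_search(input_ids, num_beams=num_beams, num_beam_groups=num_beam_groups, **kwargs)
        elif num_beams > 1:
            return self._beam_search(input_ids, num_beams=num_beams, **kwargs)
        return self._greedy_search(input_ids, **kwargs)
```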
Thanks @sanchit-gandhi!
My first commit was to get a prototype working for `num_beam_groups=1`. I intend to refactor the beam search logic to make sure it works for other `num_beam_groups` sizes.
- Will do.
- My current logic is jittable, as I've been doing some testing from this example. Are there tests in the HF repo that explicitly test whether a function is jittable? Or is it sufficient to have an E2E test that jits the function? (See the jit smoke-test sketch after this list.)
- Will do.
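A minimal sketch of the kind of E2E jit check in question (the model/tokenizer setup is omitted, and the `num_beam_groups` / `diversity_penalty` arguments are the ones this PR would add, so this is an assumption rather than current API):

```python
# Sketch only: compile the generate call with jax.jit and check it matches eager.
import functools
import jax
import jax.numpy as jnp

generate_fn = functools.partial(
    flax_model.generate, num_beams=4, num_beam_groups=2, diversity_penalty=1.0, max_length=20
)

eager_sequences = generate_fn(input_ids).sequences
jit_sequences = jax.jit(generate_fn)(input_ids).sequences

# if the group beam search logic is fully traceable, the jitted call compiles
# and produces the same sequences as the eager call
assert jnp.array_equal(eager_sequences, jit_sequences)
```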
My opinion is that we should have a separate group beam search method in Flax as well, rather than adding to the existing one.
+1 :)
(btw, there was a recent bugfix on the PT side, might be relevant here)
Awesome, sounds good @yeandy! Excited to see how this pans out!
```python
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
topk_log_probs += add_penalty * np.array(-1.0e7)

# Add additional logic for diverse beam search
```
Nice! My only nit is that we try to avoid lambda functions in transformers - would you be able to rewrite these as standard function definitions please?
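For illustration, the kind of rewrite being asked for might look like this (the lambda and the `penalize` helper are made-up examples, not the exact code in the diff):

```python
# Illustrative only: the actual lambdas in the PR may differ.
import numpy as np

# before: an inline lambda
# penalize = lambda log_probs, mask: log_probs + mask * np.array(-1.0e7)

# after: a named function with a docstring, as preferred in transformers
def penalize(log_probs, mask):
    """Add a large negative penalty to the masked positions of `log_probs`."""
    return log_probs + mask * np.array(-1.0e7)
```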
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hey @yeandy! This PR is looking in good shape - thanks for your efforts so far! Would you like to go all the way and see it to completion? Happy to help with the remainder of the integration!
Hey @sanchit-gandhi. Due to other commitments, I currently don't have bandwidth to continue this, and the timeline for me to get to this is unknown right now. If someone else wants to work on this, I'm ok with that.
Thanks for letting me know @yeandy! Best of luck with your other commitments, I hope they go well 🤗 Opening this one up to the community to complete!
For those wondering about the status of this PR: it seems all TF/Flax support has been deprecated, so this PR is no longer in scope.
Yes, this should have been closed long ago!
What does this PR do?
Mimics #9006, but for Flax.
We want to match how PyTorch's logic accounts for `group_size` and `num_beam_groups` here and here (a short sketch of the relationship follows below).
Fixes # (issue)
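A hedged sketch of the relationship this PR wants to mirror; the variable names follow the PyTorch `group_beam_search` implementation, the values are illustrative, and the loop body is elided:

```python
# Sketch of the per-group bookkeeping used by PyTorch's diverse beam search.
num_beams = 6
num_beam_groups = 3
group_size = num_beams // num_beam_groups  # beams handled per group, here 2

for beam_group_idx in range(num_beam_groups):
    group_start_idx = beam_group_idx * group_size
    group_end_idx = min(group_start_idx + group_size, num_beams)
    # ...score and select next tokens for beams in [group_start_idx, group_end_idx),
    # penalising tokens already chosen by earlier groups at this decoding step...
```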
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.